{
  "cells": [
    {
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ],
      "cell_type": "code"
    },
    {
      "metadata": {},
      "source": [
        "\n# Gromov-Wasserstein Barycenter example\n\n\nThis example is designed to show how to use the Gromov-Wasserstein barycenter\ncomputation in POT.\n\n"
      ],
      "cell_type": "markdown"
    },
    {
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Author: Erwan Vautier <erwan.vautier@gmail.com>\n#         Nicolas Courty <ncourty@irisa.fr>\n#\n# License: MIT License\n\n\nimport numpy as np\nimport scipy as sp\n\nimport scipy.ndimage as spi\nimport matplotlib.pylab as pl\nfrom sklearn import manifold\nfrom sklearn.decomposition import PCA\n\nimport ot"
      ],
      "cell_type": "code"
    },
    {
      "metadata": {},
      "source": [
        "Smacof MDS\n----------\n\nThis function computes an embedding of points from a dissimilarity matrix.\nIn this example it is applied to the distance matrices of the input shapes\nand to the matrices returned by the barycenter algorithm.\n\n"
      ],
      "cell_type": "markdown"
    },
    {
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def smacof_mds(C, dim, max_iter=3000, eps=1e-9):\n    \"\"\"\n    Returns an interpolated point cloud following the dissimilarity matrix C\n    using SMACOF multidimensional scaling (MDS) in a target space of the\n    requested dimension\n\n    Parameters\n    ----------\n    C : ndarray, shape (ns, ns)\n        dissimilarity matrix\n    dim : int\n          dimension of the targeted space\n    max_iter :  int\n        Maximum number of iterations of the SMACOF algorithm for a single run\n    eps : float\n        relative tolerance w.r.t stress to declare convergence\n\n    Returns\n    -------\n    npos : ndarray, shape (ns, dim)\n           Embedded coordinates of the interpolated point cloud (defined up\n           to an isometry)\n    \"\"\"\n\n    # A single seeded generator drives the random start below, so that the\n    # embedding is reproducible from one execution to the next\n    rng = np.random.RandomState(seed=3)\n\n    # First pass: SMACOF from a random start, only used as an initial layout\n    mds = manifold.MDS(\n        dim,\n        max_iter=max_iter,\n        eps=eps,\n        dissimilarity='precomputed',\n        random_state=rng,\n        n_init=1)\n    pos = mds.fit(C).embedding_\n\n    # Second pass: SMACOF restarted from that layout; it has to work in the\n    # same dimension as pos, hence dim rather than a fixed 2\n    nmds = manifold.MDS(\n        dim,\n        max_iter=max_iter,\n        eps=eps,\n        dissimilarity=\"precomputed\",\n        random_state=rng,\n        n_init=1)\n    npos = nmds.fit_transform(C, init=pos)\n\n    return npos"
      ],
      "cell_type": "code"
    },
    {
      "metadata": {},
      "source": [
        "Data preparation\n----------------\n\nThe four distributions are constructed from 4 simple images\n\n"
      ],
      "cell_type": "markdown"
    },
    {
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def im2mat(I):\n    \"\"\"Converts an image to a matrix (one pixel per line)\"\"\"\n    return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))\n\n\n# scipy.ndimage.imread was removed in SciPy 1.2, so the images are read with\n# matplotlib instead. For PNG files it returns floats in [0, 1], hence no\n# rescaling here. Only channel 2 of each image is kept.\nsquare = pl.imread('../data/square.png').astype(np.float64)[:, :, 2]\ncross = pl.imread('../data/cross.png').astype(np.float64)[:, :, 2]\ntriangle = pl.imread('../data/triangle.png').astype(np.float64)[:, :, 2]\nstar = pl.imread('../data/star.png').astype(np.float64)[:, :, 2]\n\nshapes = [square, cross, triangle, star]\n\nS = len(shapes)\nxs = [[] for i in range(S)]\n\n# Each pixel (i, j) of the first 8 rows and columns whose value is below 0.95\n# becomes one 2D point [j, 8 - i]\nfor nb in range(S):\n    for i in range(8):\n        for j in range(8):\n            if shapes[nb][i, j] < 0.95:\n                xs[nb].append([j, 8 - i])\n\n# The shapes may have different numbers of points, so keep a list of arrays:\n# stacking ragged arrays with np.array raises an error in recent NumPy\nxs = [np.array(pts) for pts in xs]"
      ],
      "cell_type": "code"
    },
    {
      "metadata": {},
      "source": [
        "Barycenter computation\n----------------------\n\n"
      ],
      "cell_type": "markdown"
    },
    {
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "ns = [len(xs[s]) for s in range(S)]\nn_samples = 30\n\n# Compute all distance matrices for the four shapes, each one divided by its\n# largest entry\nCs = [sp.spatial.distance.cdist(xs[s], xs[s]) for s in range(S)]\nCs = [cs / cs.max() for cs in Cs]\n\nps = [ot.unif(ns[s]) for s in range(S)]\np = ot.unif(n_samples)\n\n# Two weight pairs per couple of shapes: (1/3, 2/3) and (2/3, 1/3)\nlambdast = [[float(i) / 3, float(3 - i) / 3] for i in [1, 2]]\n\n\ndef pair_barycenters(a, b):\n    \"\"\"Gromov-Wasserstein barycenters between shapes a and b\n\n    Parameters\n    ----------\n    a, b : int\n        indices in Cs and ps of the two shapes to interpolate\n\n    Returns\n    -------\n    Ct : list\n        outputs of ot.gromov.gromov_barycenters, one per weight pair of\n        lambdast and in the same order\n    \"\"\"\n    return [ot.gromov.gromov_barycenters(n_samples, [Cs[a], Cs[b]],\n                                         [ps[a], ps[b]], p, weights,\n                                         'square_loss',\n                                         max_iter=100, tol=1e-3)\n            for weights in lambdast]\n\n\nCt01 = pair_barycenters(0, 1)\nCt02 = pair_barycenters(0, 2)\nCt13 = pair_barycenters(1, 3)\nCt23 = pair_barycenters(2, 3)"
      ],
      "cell_type": "code"
    },
    {
      "metadata": {},
      "source": [
        "Visualization\n-------------\n\nAn MDS embedding is only defined up to an isometry, so a PCA is applied to\nthe barycenter embeddings to give them consistent orientations\n\n"
      ],
      "cell_type": "markdown"
    },
    {
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "clf = PCA(n_components=2)\n\n\ndef embed_barycenters(Ct):\n    \"\"\"Embeds every matrix of Ct in 2D with smacof_mds, then applies the PCA\n\n    Parameters\n    ----------\n    Ct : list\n        matrices to embed, passed one by one to smacof_mds\n\n    Returns\n    -------\n    pos : list\n        one PCA-transformed 2D embedding per matrix of Ct, in the same order\n    \"\"\"\n    pos = [smacof_mds(C, 2) for C in Ct]\n    return [clf.fit_transform(x) for x in pos]\n\n\nnpos = [smacof_mds(Cs[s], 2) for s in range(S)]\n\nnpost01 = embed_barycenters(Ct01)\nnpost02 = embed_barycenters(Ct02)\nnpost13 = embed_barycenters(Ct13)\nnpost23 = embed_barycenters(Ct23)\n\n# One entry per panel: position in the 4 x 4 grid, points to draw, color.\n# The four input shapes are drawn in red in the corners, the barycenters in\n# blue on the edges between the two shapes they were computed from.\npanels = [\n    ((0, 0), npos[0], 'r'),\n    ((0, 1), npost01[1], 'b'),\n    ((0, 2), npost01[0], 'b'),\n    ((0, 3), npos[1], 'r'),\n    ((1, 0), npost02[1], 'b'),\n    ((1, 3), npost13[1], 'b'),\n    ((2, 0), npost02[0], 'b'),\n    ((2, 3), npost13[0], 'b'),\n    ((3, 0), npos[2], 'r'),\n    ((3, 1), npost23[1], 'b'),\n    ((3, 2), npost23[0], 'b'),\n    ((3, 3), npos[3], 'r'),\n]\n\nfig = pl.figure(figsize=(10, 10))\n\nfor loc, pts, color in panels:\n    ax = pl.subplot2grid((4, 4), loc)\n    # Same fixed window on every panel so that the shapes are comparable\n    ax.set_xlim((-1, 1))\n    ax.set_ylim((-1, 1))\n    ax.scatter(pts[:, 0], pts[:, 1], color=color)\n\npl.show()"
      ],
      "cell_type": "code"
    }
  ],
  "metadata": {
    "language_info": {
      "name": "python",
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "nbconvert_exporter": "python",
      "version": "3.5.2",
      "pygments_lexer": "ipython3",
      "file_extension": ".py",
      "mimetype": "text/x-python"
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3",
      "language": "python"
    }
  },
  "nbformat_minor": 0,
  "nbformat": 4
}